package websocket

import (
	"context"
	"errors"
	"gitee.com/zackeus/go-boot/common/constants/net/headers"
	"gitee.com/zackeus/go-boot/tools/httpx/response"
	message2 "gitee.com/zackeus/go-boot/websocket/message"
	"gitee.com/zackeus/goutil"
	"gitee.com/zackeus/goutil/strutil"
	"github.com/gorilla/websocket"
	cmap "github.com/orcaman/concurrent-map/v2"
	"github.com/panjf2000/ants/v2"
	"net/http"
	"time"
)

// BCIterCb is the per-client filter consulted by BroadcastWithIterator. It
// receives the client's map key and the client itself; returning false skips
// that client.
type BCIterCb func(key string, c *Client) bool

// Option mutates an Engine. NewEngine applies options in order after the
// defaults have been set.
type Option func(e *Engine)

// Engine upgrades HTTP requests to WebSocket connections, tracks the
// registered clients by id and dispatches client event callbacks through a
// goroutine pool.
type Engine struct {
	// upGrader performs the HTTP -> WebSocket handshake in ServeWs.
	upGrader *websocket.Upgrader
	// clients holds the registered clients keyed by the id given to Register.
	clients cmap.ConcurrentMap[string, *Client]
	// pool runs the event callbacks handed to submitOnFunc.
	pool *ants.Pool
	// pingPeriod is the ping interval (set via WithPingPeriod).
	// NOTE(review): consumed outside this file, presumably by Client.
	pingPeriod time.Duration
	// deadLineWait bounds read/write operations; in this file it is the
	// deadline for the close control frames sent by ServeWs.
	deadLineWait time.Duration
	// maxMessageSize is the maximum size in bytes of a message read from a
	// peer (set via WithMaxMessageSize).
	// NOTE(review): consumed outside this file, presumably by Client.
	maxMessageSize int64
	// generateKey derives the client key from the upgrade request.
	generateKey func(r *http.Request) (string, error)
}

// NewEngine builds an Engine with default settings and then applies options
// in the order given.
//
// Defaults: a blocking goroutine pool of size 1000, a one-minute ping period,
// a three-second deadline wait, a maximum message size of 8192 and a client
// key read from the request header named by headers.SecWebSocketKey.
//
// An error is returned only when the goroutine pool cannot be created.
func NewEngine(options ...Option) (*Engine, error) {
	// Blocking pool: submissions wait for a free worker instead of failing.
	workers, err := ants.NewPool(1000, ants.WithNonblocking(false))
	if err != nil {
		return nil, err
	}

	// Default client key: the value of the handshake key header.
	headerKey := func(r *http.Request) (string, error) {
		return r.Header.Get(headers.SecWebSocketKey), nil
	}

	engine := new(Engine)
	engine.upGrader = &websocket.Upgrader{}
	engine.clients = cmap.New[*Client]()
	engine.pool = workers
	engine.pingPeriod = time.Minute
	engine.deadLineWait = 3 * time.Second
	engine.maxMessageSize = 8192
	engine.generateKey = headerKey

	// Options run last so they can override any default above.
	for i := range options {
		options[i](engine)
	}
	return engine, nil
}

// ServeWs upgrades the HTTP request to a WebSocket connection and wraps it in
// a Client. The returned client is not added to the engine's client map here;
// that is done by Register.
//
// The client key is produced by the engine's generateKey function. When key
// generation fails or yields a blank key, the peer is sent a close frame with
// CloseInternalServerErr, the connection is closed, and an error is returned.
// All event handlers of the returned client start out as no-ops.
func (e *Engine) ServeWs(w http.ResponseWriter, r *http.Request) (*Client, error) {
	// Upgrade the HTTP request to a WebSocket connection.
	conn, err := e.upGrader.Upgrade(w, r, nil)
	if err != nil {
		// NOTE(review): gorilla's Upgrade documents that it already replies
		// with an HTTP error response on failure, so this may be a second
		// write to w — confirm ErrorJson tolerates that.
		response.ErrorJson(r.Context(), w, err)
		return nil, err
	}

	// reject tells the peer why the connection is refused and then releases
	// the underlying connection so it is not leaked.
	reject := func(reason string) {
		msg := &message2.CloseMessage{Code: message2.CloseInternalServerErr, Value: reason}
		// Best effort: the peer may already be gone, and the connection is
		// being torn down either way.
		_ = conn.WriteControl(int(message2.TypeClose), msg.Format(), time.Now().Add(e.deadLineWait))
		_ = conn.Close()
	}

	// Generate the client key.
	key, err := e.generateKey(r)
	if err != nil {
		reject("ClientKey generation failed")
		return nil, err
	}
	if strutil.IsBlank(key) {
		reject("invalid client key")
		return nil, errors.New("invalid client key")
	}

	// The client's context is rooted at Background, not at the request's
	// context; it ends only when cancel is called.
	ctx, cancel := context.WithCancel(context.Background())
	c := &Client{
		ctx:            ctx,
		cancel:         cancel,
		engine:         e,
		key:            key,
		conn:           conn,
		onPing:         func(client *Client, data string) {},
		onPong:         func(client *Client, data string) {},
		onMessage:      func(client *Client, message *message2.Message) {},
		onConnected:    func(client *Client, isForce bool) {},
		onDisConnected: func(client *Client, code message2.CloseCode, reason string, isForce bool) {},
		onError:        func(client *Client, err error) {},
	}
	return c, nil
}

// WithCheckOrigin installs f as the upgrader's origin check; supplying a
// permissive f is how cross-origin handshakes can be allowed.
func WithCheckOrigin(f func(r *http.Request) bool) Option {
	setOrigin := func(engine *Engine) { engine.upGrader.CheckOrigin = f }
	return setOrigin
}

// WithPingPeriod overrides the engine's ping period.
func WithPingPeriod(period time.Duration) Option {
	setPeriod := func(engine *Engine) { engine.pingPeriod = period }
	return setPeriod
}

// WithDeadLineWait overrides the engine's read/write timeout.
func WithDeadLineWait(wait time.Duration) Option {
	setWait := func(engine *Engine) { engine.deadLineWait = wait }
	return setWait
}

// WithMaxMessageSize overrides the maximum size, in bytes, of a message read
// from a peer.
func WithMaxMessageSize(size int64) Option {
	setLimit := func(engine *Engine) { engine.maxMessageSize = size }
	return setLimit
}

// WithGenerateKey replaces the function that derives a client key from the
// upgrade request.
//
// A nil f is ignored and the engine keeps its current generator: ServeWs
// calls the generator unconditionally, so storing nil would make it panic on
// a connection that has already been upgraded.
func WithGenerateKey(f func(r *http.Request) (string, error)) Option {
	return func(e *Engine) {
		if f == nil {
			return
		}
		e.generateKey = f
	}
}

// Register stores client in the engine under id, initialises it with options
// and emits its connected event.
//
// If another client is already stored under id, that client is closed with
// CloseForceUnRegister and receives a forced disconnected event; the new
// client's connected event then reports isForce as true.
func (e *Engine) Register(id string, client *Client, options ...ClientOption) {
	replaced := false

	// swap is invoked by the map with the current entry (if any) and always
	// answers with the new client as the value to store.
	swap := func(exist bool, valueInMap *Client, newValue *Client) *Client {
		if !exist {
			return newValue
		}

		replaced = true
		notice := &message2.CloseMessage{Code: message2.CloseForceUnRegister, Value: "Force UnRegister."}
		// Evict the previous holder of this id.
		_ = valueInMap.Close(&message2.CloseMessage{Code: message2.CloseForceUnRegister, Value: "Force UnRegister."})
		// Disconnected event for the evicted client.
		e.submitOnFunc(func() { valueInMap.onDisConnected(valueInMap, notice.Code, notice.Value, replaced) })
		return newValue
	}
	e.clients.Upsert(id, client, swap)

	// Start the new client.
	client.init(id, options)
	// Connected event for the new client.
	e.submitOnFunc(func() { client.onConnected(client, replaced) })
}

// UnRegister removes client from the engine and closes it.
//
// The entry is removed only when the client stored under client.ID() has the
// same ClientKey as client, so a newer client registered under the same id is
// left untouched. The first non-nil element of msg, if any, is used as the
// close message; otherwise a normal-closure message is sent. A disconnected
// event (isForce == false) is emitted for the removed client.
//
// It returns true when the client was removed, and false when client is nil,
// not registered, or registered under a different ClientKey.
func (e *Engine) UnRegister(client *Client, msg ...*message2.CloseMessage) bool {
	if client == nil {
		return false
	}

	return e.clients.RemoveCb(client.ID(), func(key string, v *Client, exists bool) bool {
		if !exists || !goutil.IsEqual(v.ClientKey(), client.ClientKey()) {
			// Not registered, or the id now belongs to a different client.
			return false
		}

		// Fall back to a normal closure unless the caller supplied a usable
		// message; an empty or nil-valued variadic must not be dereferenced.
		closeMsg := &message2.CloseMessage{Code: message2.CloseNormalClosure, Value: "Normal Close."}
		if len(msg) > 0 && msg[0] != nil {
			closeMsg = msg[0]
		}
		_ = v.Close(closeMsg)

		// Disconnected event.
		e.submitOnFunc(func() { v.onDisConnected(v, closeMsg.Code, closeMsg.Value, false) })
		return true
	})
}

// Close removes the client registered under id and closes it.
//
// The first non-nil element of msg, if any, is used as the close message;
// otherwise a normal-closure message is sent. A disconnected event
// (isForce == false) is emitted for the removed client. Unlike UnRegister, no
// ClientKey comparison is made: whichever client holds id is closed.
//
// It returns the removed client and true, or (nil, false) when no client is
// registered under id.
func (e *Engine) Close(id string, msg ...*message2.CloseMessage) (*Client, bool) {
	client, found := e.clients.Pop(id)
	if !found {
		return client, false
	}

	// Fall back to a normal closure unless the caller supplied a usable
	// message; an empty variadic slice must not be indexed.
	closeMsg := &message2.CloseMessage{Code: message2.CloseNormalClosure, Value: "Normal Close."}
	if len(msg) > 0 && msg[0] != nil {
		closeMsg = msg[0]
	}
	_ = client.Close(closeMsg)

	// Disconnected event.
	e.submitOnFunc(func() { client.onDisConnected(client, closeMsg.Code, closeMsg.Value, false) })
	return client, true
}

// Broadcast sends msg to every registered client.
//
// A client whose Send fails with websocket.ErrCloseSent is treated as already
// closed and is unregistered once the iteration has finished. Any other send
// error is reported through that client's onError handler.
func (e *Engine) Broadcast(msg *message2.Message) {
	var stale []*Client

	e.clients.IterCb(func(_ string, target *Client) {
		sendErr := target.Send(msg)
		if sendErr == nil {
			return
		}
		if errors.Is(sendErr, websocket.ErrCloseSent) {
			// The connection has already been closed; collect it for removal.
			stale = append(stale, target)
			return
		}
		// Broadcast to this client failed.
		e.submitOnFunc(func() { target.onError(target, sendErr) })
	})

	// Unregister the clients found closed, now that iteration is over.
	for i := range stale {
		e.UnRegister(stale[i])
	}
}

// BroadcastWithIterator sends msg to every registered client accepted by fn.
//
// fn is called with each client's map key and the client; returning false
// skips that client. A nil fn accepts every client, which makes the call
// behave like Broadcast.
//
// A client whose Send fails with websocket.ErrCloseSent is treated as already
// closed and is unregistered once the iteration has finished. Any other send
// error is reported through that client's onError handler.
func (e *Engine) BroadcastWithIterator(msg *message2.Message, fn BCIterCb) {
	var closedClients []*Client

	e.clients.IterCb(func(key string, v *Client) {
		// A nil filter means "no filtering" rather than a nil-func panic in
		// the middle of the map iteration.
		if fn != nil && !fn(key, v) {
			return
		}

		err := v.Send(msg)
		switch {
		case err == nil:
			// Delivered.
		case errors.Is(err, websocket.ErrCloseSent):
			// The connection has already been closed; collect it for removal.
			closedClients = append(closedClients, v)
		default:
			// Broadcast to this client failed.
			e.submitOnFunc(func() { v.onError(v, err) })
		}
	})

	// Unregister the clients found closed, now that iteration is over.
	for _, client := range closedClients {
		e.UnRegister(client)
	}
}

// submitOnFunc hands an event callback to the engine's goroutine pool.
//
// Delivery is best-effort: if the pool rejects the task, the callback is
// dropped and nothing is reported to the caller.
func (e *Engine) submitOnFunc(f func()) {
	if err := e.pool.Submit(f); err != nil {
		// Deliberately ignored; events are fire-and-forget.
		return
	}
}
